#!/usr/bin/env python2

"""
    psud
    PSU information update daemon for SONiC
    This daemon will loop to collect PSU related information and then write the information to state DB.
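    It prefers the new platform API (sonic_platform) and falls back to the old-style psuutil plugin when the
    platform API cannot be loaded. The number of PSUs, PSU presence and PSU status are always published; when
    the new platform API is available, PSU temperature, voltage, their thresholds and the LED status are
    published as well.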
    The loop interval is PSU_INFO_UPDATE_PERIOD_SECS in seconds.
"""

try:
    import sys
    import time
    import signal
    import threading
    from swsscommon import swsscommon
    from sonic_daemon_base import daemon_base
    from sonic_daemon_base.daemon_base import Logger
    from sonic_daemon_base.daemon_base import DaemonBase
except ImportError, e:
    raise ImportError (str(e) + " - required module not found")

#
# Constants ====================================================================
#

# Identifier handed to Logger when the module-level logger is created
SYSLOG_IDENTIFIER = "psud"

# Module and class name of the old-style platform plugin, passed to load_platform_util()
PLATFORM_SPECIFIC_MODULE_NAME = "psuutil"
PLATFORM_SPECIFIC_CLASS_NAME = "PsuUtil"

# Table name, key template and field used to publish the number of PSUs
CHASSIS_INFO_TABLE = 'CHASSIS_INFO'
CHASSIS_INFO_KEY_TEMPLATE = 'chassis {}'
CHASSIS_INFO_PSU_NUM_FIELD = 'psu_num'

# Table name, key template (formatted with the 1-based PSU index) and field names
# used to publish per-PSU information
PSU_INFO_TABLE = 'PSU_INFO'
PSU_INFO_KEY_TEMPLATE = 'PSU {}'
PSU_INFO_PRESENCE_FIELD = 'presence'
PSU_INFO_STATUS_FIELD = 'status'
PSU_INFO_TEMP_FIELD = 'temp'
PSU_INFO_TEMP_TH_FIELD = 'temp_threshold'
PSU_INFO_VOLTAGE_FIELD = 'voltage'
PSU_INFO_VOLTAGE_MAX_TH_FIELD = 'voltage_max_threshold'
PSU_INFO_VOLTAGE_MIN_TH_FIELD = 'voltage_min_threshold'

# Timeout, in seconds, of each wait in the daemon main loop
PSU_INFO_UPDATE_PERIOD_SECS = 3

# Process exit code used when the psuutil plugin cannot be loaded
PSUUTIL_LOAD_ERROR = 1

logger = Logger(SYSLOG_IDENTIFIER)

# Both handles start as None and are assigned by DaemonPsud.run()
platform_psuutil = None
platform_chassis = None

# Temporary wrappers that are compatible with both the new platform API and the old-style plugin mode
def _wrapper_get_num_psus():
    """
    Get the number of PSUs, preferring the platform chassis over the psuutil plugin
    :return: Whatever get_num_psus() of the selected backend returns
    """
    if platform_chassis is None:
        return platform_psuutil.get_num_psus()

    try:
        return platform_chassis.get_num_psus()
    except NotImplementedError:
        # The chassis does not implement this call, use the plugin instead
        return platform_psuutil.get_num_psus()

def _wrapper_get_psus_presence(psu_index):
    """
    Get the presence of one PSU, preferring the platform chassis over the psuutil plugin
    :param psu_index: 1-based PSU index; the chassis is queried with psu_index - 1
    :return: Whatever the selected backend reports as PSU presence
    """
    if platform_chassis is None:
        return platform_psuutil.get_psu_presence(psu_index)

    try:
        chassis_psu = platform_chassis.get_psu(psu_index - 1)
        return chassis_psu.get_presence()
    except NotImplementedError:
        # The platform API does not implement this call, use the plugin instead
        return platform_psuutil.get_psu_presence(psu_index)

def _wrapper_get_psus_status(psu_index):
    """
    Get the power good status of one PSU, preferring the platform chassis over the psuutil plugin
    :param psu_index: 1-based PSU index; the chassis is queried with psu_index - 1
    :return: Whatever the selected backend reports as PSU status
    """
    if platform_chassis is None:
        return platform_psuutil.get_psu_status(psu_index)

    try:
        chassis_psu = platform_chassis.get_psu(psu_index - 1)
        return chassis_psu.get_powergood_status()
    except NotImplementedError:
        # The platform API does not implement this call, use the plugin instead
        return platform_psuutil.get_psu_status(psu_index)


#
# Helper functions =============================================================
#

def psu_db_update(psu_tbl, psu_num):
    """
    Write presence and status of every PSU to the PSU table
    :param psu_tbl: Table object the PSU entries are written to
    :param psu_num: Number of PSUs; indexes 1..psu_num are updated
    :return:
    """
    for psu_index in range(1, psu_num + 1):
        presence_str = 'true' if _wrapper_get_psus_presence(psu_index) else 'false'
        status_str = 'true' if _wrapper_get_psus_status(psu_index) else 'false'
        field_values = swsscommon.FieldValuePairs([
            (PSU_INFO_PRESENCE_FIELD, presence_str),
            (PSU_INFO_STATUS_FIELD, status_str),
        ])
        psu_tbl.set(PSU_INFO_KEY_TEMPLATE.format(psu_index), field_values)


# Try to get information from the platform API and return a default value if NotImplementedError is raised
def try_get(callback, default=None):
    """
    Invoke the callback, substituting a default when it has nothing to offer
    :param callback: Callable taking no arguments
    :param default: Value returned when the callback raises NotImplementedError or returns None
    :return: Return value of the callback, or the default as described above
    """
    try:
        result = callback()
    except NotImplementedError:
        return default

    return default if result is None else result

def log_on_status_changed(normal_status, normal_log, abnormal_log):
    """
    Emit one of two log messages depending on the given status.
    This function does not detect changes itself; the caller decides when to call it.
    :param normal_status: Truthy when the status is the expected one.
    :param normal_log: Message logged as a notice for the expected status.
    :param abnormal_log: Message logged as a warning for the unexpected status.
    :return:
    """
    if not normal_status:
        logger.log_warning(abnormal_log)
        return

    logger.log_notice(normal_log)

#
# PSU status ===================================================================
#

class PsuStatus(object):
    """
    Cached health flags of a single PSU.

    Every setter compares the new reading against the cached flag and tells the
    caller whether the flag flipped, so that logging and LED updates only happen
    on a transition.
    """

    def __init__(self, psu):
        """
        :param psu: PSU object this status belongs to; it is only stored here
        """
        self.psu = psu
        # All flags start as "good", so a first healthy reading is not reported as a change
        self.presence = True
        self.power_good = True
        self.voltage_good = True
        self.temperature_good = True

    def set_presence(self, presence):
        """
        Set and cache PSU presence status
        :param presence: PSU presence status
        :return: True if status changed else False
        """
        if presence == self.presence:
            return False

        self.presence = presence
        return True

    def set_power_good(self, power_good):
        """
        Set and cache PSU power good status
        :param power_good: PSU power good status
        :return: True if status changed else False
        """
        if power_good == self.power_good:
            return False

        self.power_good = power_good
        return True

    def set_voltage(self, voltage, high_threshold, low_threshold):
        """
        Evaluate the voltage against its thresholds and cache the result
        :param voltage: Current PSU voltage, None if unavailable
        :param high_threshold: Highest valid voltage, None if unavailable
        :param low_threshold: Lowest valid voltage, None if unavailable
        :return: True if the voltage status changed else False. When any input is
                 unavailable the status is reset to good and False is returned.
        """
        # Compare against None explicitly: a reading of 0 is a real value (and an
        # out-of-range one for voltage), not a missing value.
        if voltage is None or high_threshold is None or low_threshold is None:
            if self.voltage_good is not True:
                logger.log_warning('PSU voltage or high_threshold or low_threshold become unavailable, '
                                   'voltage={}, high_threshold={}, low_threshold={}'.format(voltage, high_threshold, low_threshold))
                self.voltage_good = True
            return False

        voltage_good = (low_threshold <= voltage <= high_threshold)
        if voltage_good == self.voltage_good:
            return False

        self.voltage_good = voltage_good
        return True

    def set_temperature(self, temperature, high_threshold):
        """
        Evaluate the temperature against its threshold and cache the result
        :param temperature: Current PSU temperature, None if unavailable
        :param high_threshold: Temperature must stay strictly below this value, None if unavailable
        :return: True if the temperature status changed else False. When any input is
                 unavailable the status is reset to good and False is returned.
        """
        # Compare against None explicitly: a temperature of 0 is a valid reading
        if temperature is None or high_threshold is None:
            if self.temperature_good is not True:
                logger.log_warning('PSU temperature or high_threshold become unavailable, '
                                   'temperature={}, high_threshold={}'.format(temperature, high_threshold))
                self.temperature_good = True
            return False

        temperature_good = (temperature < high_threshold)
        if temperature_good == self.temperature_good:
            return False

        self.temperature_good = temperature_good
        return True

    def is_ok(self):
        """
        :return: Truthy only if presence, power, voltage and temperature are all good
        """
        return self.presence and self.power_good and self.voltage_good and self.temperature_good
        
#
# Daemon =======================================================================
#

class DaemonPsud(DaemonBase):
    """
    Daemon that periodically collects PSU information and publishes it to STATE_DB.
    """

    # Written to the 'led_status' field when reading the PSU LED state fails
    LED_STATUS_NOT_AVAILABLE = 'N/A'

    def __init__(self):
        DaemonBase.__init__(self)

        # Set by the signal handler to make the main loop exit
        self.stop = threading.Event()
        # Maps the 1-based PSU index to its cached PsuStatus
        self.psu_status_dict = {}

    # Signal handler
    def signal_handler(self, sig, frame):
        """
        Handle a signal: SIGHUP is ignored, SIGINT and SIGTERM stop the main loop,
        anything else is only logged.
        :param sig: Signal number
        :param frame: Current stack frame (unused)
        :return:
        """
        if sig == signal.SIGHUP:
            logger.log_info("Caught SIGHUP - ignoring...")
        elif sig == signal.SIGINT:
            logger.log_info("Caught SIGINT - exiting...")
            self.stop.set()
        elif sig == signal.SIGTERM:
            logger.log_info("Caught SIGTERM - exiting...")
            self.stop.set()
        else:
            # sig is an integer, so it has to be formatted rather than concatenated
            logger.log_warning("Caught unhandled signal '{}'".format(sig))

    # Run daemon
    def run(self):
        """
        Load the platform backend, publish the PSU count, then update the PSU
        information every PSU_INFO_UPDATE_PERIOD_SECS seconds until stopped.
        All published entries are deleted again before returning.
        Exits the process with PSUUTIL_LOAD_ERROR if neither the platform API
        nor the psuutil plugin can be loaded.
        :return:
        """
        global platform_psuutil
        global platform_chassis

        logger.log_info("Starting up...")

        # Load new platform api class
        try:
            import sonic_platform.platform
            platform_chassis = sonic_platform.platform.Platform().get_chassis()
        except Exception as e:
            logger.log_warning("Failed to load chassis due to {}".format(repr(e)))

        # Load platform-specific psuutil class
        if platform_chassis is None:
            try:
                platform_psuutil = self.load_platform_util(PLATFORM_SPECIFIC_MODULE_NAME, PLATFORM_SPECIFIC_CLASS_NAME)
            except Exception as e:
                logger.log_error("Failed to load psuutil: %s" % (str(e)), True)
                sys.exit(PSUUTIL_LOAD_ERROR)

        # Connect to STATE_DB and create psu/chassis info tables
        state_db = daemon_base.db_connect("STATE_DB")
        chassis_tbl = swsscommon.Table(state_db, CHASSIS_INFO_TABLE)
        psu_tbl = swsscommon.Table(state_db, PSU_INFO_TABLE)

        # Post psu number info to STATE_DB
        psu_num = _wrapper_get_num_psus()
        fvs = swsscommon.FieldValuePairs([(CHASSIS_INFO_PSU_NUM_FIELD, str(psu_num))])
        chassis_tbl.set(CHASSIS_INFO_KEY_TEMPLATE.format(1), fvs)

        # Start main loop
        logger.log_info("Start daemon main loop")

        # wait() returns True as soon as the stop event is set, which ends the loop
        while not self.stop.wait(PSU_INFO_UPDATE_PERIOD_SECS):
            psu_db_update(psu_tbl, psu_num)
            self.update_psu_data(psu_tbl)
            self._update_led_color(psu_tbl)

        logger.log_info("Stop daemon main loop")

        # Delete all the information from DB and then exit
        for psu_index in range(1, psu_num + 1):
            psu_tbl._del(PSU_INFO_KEY_TEMPLATE.format(psu_index))

        chassis_tbl._del(CHASSIS_INFO_KEY_TEMPLATE.format(1))

        logger.log_info("Shutting down...")

    def update_psu_data(self, psu_tbl):
        """
        Update status and sensor data of every PSU reported by the platform chassis.
        Does nothing when the platform chassis is not available. A failure on one
        PSU is logged and does not prevent the remaining PSUs from being updated.
        :param psu_tbl: Table object the PSU entries are written to
        :return:
        """
        if not platform_chassis:
            return

        for index, psu in enumerate(platform_chassis.get_all_psus()):
            try:
                # PSU entries are keyed by a 1-based index
                self._update_single_psu_data(index + 1, psu, psu_tbl)
            except Exception as e:
                logger.log_warning("Failed to update PSU data - {}".format(e))

    def _update_single_psu_data(self, index, psu, psu_tbl):
        """
        Read one PSU, log every status transition, update its LED when any status
        changed and write the temperature/voltage readings to the PSU table.
        :param index: 1-based PSU index
        :param psu: PSU object of the platform API
        :param psu_tbl: Table object the PSU entry is written to
        :return:
        """
        name = try_get(psu.get_name)
        if not name:
            name = PSU_INFO_KEY_TEMPLATE.format(index)

        presence = _wrapper_get_psus_presence(index)
        power_good = False
        voltage = None
        voltage_high_threshold = None
        voltage_low_threshold = None
        temperature = None
        temperature_threshold = None
        # Sensors are only read while the PSU is present
        if presence:
            power_good = _wrapper_get_psus_status(index)
            voltage = try_get(psu.get_voltage)
            voltage_high_threshold = try_get(psu.get_voltage_high_threshold)
            voltage_low_threshold = try_get(psu.get_voltage_low_threshold)
            temperature = try_get(psu.get_temperature)
            temperature_threshold = try_get(psu.get_temperature_high_threshold)

        if index not in self.psu_status_dict:
            self.psu_status_dict[index] = PsuStatus(psu)

        psu_status = self.psu_status_dict[index]
        set_led = False
        if psu_status.set_presence(presence):
            set_led = True
            log_on_status_changed(
                psu_status.presence,
                'PSU absence warning cleared: {} is inserted back.'.format(name),
                'PSU absence warning: {} is not present.'.format(name)
            )

        if presence and psu_status.set_power_good(power_good):
            set_led = True
            log_on_status_changed(
                psu_status.power_good,
                'Power absence warning cleared: {} power is back to normal.'.format(name),
                'Power absence warning: {} is out of power.'.format(name)
            )

        if presence and psu_status.set_voltage(voltage, voltage_high_threshold, voltage_low_threshold):
            set_led = True
            # The valid range is printed as [low, high]
            log_on_status_changed(
                psu_status.voltage_good,
                'PSU voltage warning cleared: {} voltage is back to normal.'.format(name),
                'PSU voltage warning: {} voltage out of range, current voltage={}, valid range=[{}, {}].'.format(
                    name, voltage, voltage_low_threshold, voltage_high_threshold)
            )

        if presence and psu_status.set_temperature(temperature, temperature_threshold):
            set_led = True
            log_on_status_changed(
                psu_status.temperature_good,
                'PSU temperature warning cleared: {} temperature is back to normal.'.format(name),
                'PSU temperature warning: {} temperature too hot, temperature={}, threshold={}.'.format(
                    name, temperature, temperature_threshold)
            )

        if set_led:
            self._set_psu_led(psu, psu_status)

        # Readings are stored as strings; an unavailable reading is stored as 'None'
        fvs = swsscommon.FieldValuePairs(
            [(PSU_INFO_TEMP_FIELD, str(temperature)),
             (PSU_INFO_TEMP_TH_FIELD, str(temperature_threshold)),
             (PSU_INFO_VOLTAGE_FIELD, str(voltage)),
             (PSU_INFO_VOLTAGE_MIN_TH_FIELD, str(voltage_low_threshold)),
             (PSU_INFO_VOLTAGE_MAX_TH_FIELD, str(voltage_high_threshold)),
            ])
        psu_tbl.set(PSU_INFO_KEY_TEMPLATE.format(index), fvs)

    def _set_psu_led(self, psu, psu_status):
        """
        Set the PSU LED to green when the PSU is fully healthy, red otherwise.
        Platforms that do not implement LED control are silently skipped.
        :param psu: PSU object of the platform API
        :param psu_status: Cached PsuStatus of that PSU
        :return:
        """
        try:
            color = psu.STATUS_LED_COLOR_GREEN if psu_status.is_ok() else psu.STATUS_LED_COLOR_RED
            psu.set_status_led(color)
        except NotImplementedError:
            # LED control is optional, nothing to do on platforms without it
            pass

    def _update_led_color(self, psu_tbl):
        """
        Write the current LED status of every known PSU to the PSU table.
        Does nothing when the platform chassis is not available.
        :param psu_tbl: Table object the PSU entries are written to
        :return:
        """
        if not platform_chassis:
            return

        for index, psu_status in self.psu_status_dict.items():
            try:
                led_status = str(try_get(psu_status.psu.get_status_led))
            except Exception:
                logger.log_warning('Failed to get led status for psu {}'.format(index))
                led_status = self.LED_STATUS_NOT_AVAILABLE
            fvs = swsscommon.FieldValuePairs([
                ('led_status', led_status)
            ])
            psu_tbl.set(PSU_INFO_KEY_TEMPLATE.format(index), fvs)
#
# Main =========================================================================
#

def main():
    """Create the PSU daemon and run it."""
    DaemonPsud().run()


if __name__ == '__main__':
    main()
